{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WordPiece tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = [\n",
    "    \"This is the Hugging Face course.\",\n",
    "    \"This chapter is about tokenization.\",\n",
    "    \"This section shows several tokenizer algorithms.\",\n",
    "    \"Hopefully, you will be able to understand how they are trained and generate tokens.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(\n",
       "    int, {'This': 3, 'is': 2, 'the': 1, 'Hugging': 1, 'Face': 1, 'Course': 1, '.': 4, 'chapter': 1, 'about': 1,\n",
       "    'tokenization': 1, 'section': 1, 'shows': 1, 'several': 1, 'tokenizer': 1, 'algorithms': 1, 'Hopefully': 1,\n",
       "    ',': 1, 'you': 1, 'will': 1, 'be': 1, 'able': 1, 'to': 1, 'understand': 1, 'how': 1, 'they': 1, 'are': 1,\n",
       "    'trained': 1, 'and': 1, 'generate': 1, 'tokens': 1})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "word_freqs = defaultdict(int)\n",
    "for text in corpus:\n",
    "    words_with_offsets = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    new_words = [word for word, offset in words_with_offsets]\n",
    "    for word in new_words:\n",
    "        word_freqs[word] += 1\n",
    "\n",
    "word_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k', '##l', '##m', '##n', '##o', '##p', '##r', '##s',\n",
       " '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H', 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u',\n",
       " 'w', 'y']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alphabet = []\n",
    "for word in word_freqs.keys():\n",
    "    if word[0] not in alphabet:\n",
    "        alphabet.append(word[0])\n",
    "    for letter in word[1:]:\n",
    "        if f\"##{letter}\" not in alphabet:\n",
    "            alphabet.append(f\"##{letter}\")\n",
    "\n",
    "alphabet.sort()\n",
    "alphabet\n",
    "\n",
    "print(alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = [\"[PAD]\", \"[UNK]\", \"[CLS]\", \"[SEP]\", \"[MASK]\"] + alphabet.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = {\n",
    "    word: [c if i == 0 else f\"##{c}\" for i, c in enumerate(word)]\n",
    "    for word in word_freqs.keys()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_pair_scores(splits):\n",
    "    letter_freqs = defaultdict(int)\n",
    "    pair_freqs = defaultdict(int)\n",
    "    for word, freq in word_freqs.items():\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            letter_freqs[split[0]] += freq\n",
    "            continue\n",
    "        for i in range(len(split) - 1):\n",
    "            pair = (split[i], split[i + 1])\n",
    "            letter_freqs[split[i]] += freq\n",
    "            pair_freqs[pair] += freq\n",
    "        letter_freqs[split[-1]] += freq\n",
    "\n",
    "    scores = {\n",
    "        pair: freq / (letter_freqs[pair[0]] * letter_freqs[pair[1]])\n",
    "        for pair, freq in pair_freqs.items()\n",
    "    }\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('T', '##h'): 0.125\n",
       "('##h', '##i'): 0.03409090909090909\n",
       "('##i', '##s'): 0.02727272727272727\n",
       "('i', '##s'): 0.1\n",
       "('t', '##h'): 0.03571428571428571\n",
       "('##h', '##e'): 0.011904761904761904"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pair_scores = compute_pair_scores(splits)\n",
    "for i, key in enumerate(pair_scores.keys()):\n",
    "    print(f\"{key}: {pair_scores[key]}\")\n",
    "    if i >= 5:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('a', '##b') 0.2"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_pair = \"\"\n",
    "max_score = None\n",
    "for pair, score in pair_scores.items():\n",
    "    if max_score is None or max_score < score:\n",
    "        best_pair = pair\n",
    "        max_score = score\n",
    "\n",
    "print(best_pair, max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab.append(\"ab\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_pair(a, b, splits):\n",
    "    for word in word_freqs:\n",
    "        split = splits[word]\n",
    "        if len(split) == 1:\n",
    "            continue\n",
    "        i = 0\n",
    "        while i < len(split) - 1:\n",
    "            if split[i] == a and split[i + 1] == b:\n",
    "                merge = a + b[2:] if b.startswith(\"##\") else a + b\n",
    "                split = split[:i] + [merge] + split[i + 2 :]\n",
    "            else:\n",
    "                i += 1\n",
    "        splits[word] = split\n",
    "    return splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['ab', '##o', '##u', '##t']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "splits = merge_pair(\"a\", \"##b\", splits)\n",
    "splits[\"about\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab_size = 70\n",
    "while len(vocab) < vocab_size:\n",
    "    scores = compute_pair_scores(splits)\n",
    "    best_pair, max_score = \"\", None\n",
    "    for pair, score in scores.items():\n",
    "        if max_score is None or max_score < score:\n",
    "            best_pair = pair\n",
    "            max_score = score\n",
    "    splits = merge_pair(*best_pair, splits)\n",
    "    new_token = (\n",
    "        best_pair[0] + best_pair[1][2:]\n",
    "        if best_pair[1].startswith(\"##\")\n",
    "        else best_pair[0] + best_pair[1]\n",
    "    )\n",
    "    vocab.append(new_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', '##a', '##b', '##c', '##d', '##e', '##f', '##g', '##h', '##i', '##k',\n",
       " '##l', '##m', '##n', '##o', '##p', '##r', '##s', '##t', '##u', '##v', '##w', '##y', '##z', ',', '.', 'C', 'F', 'H',\n",
       " 'T', 'a', 'b', 'c', 'g', 'h', 'i', 's', 't', 'u', 'w', 'y', '##fu', 'Fa', 'Fac', '##ct', '##ful', '##full', '##fully',\n",
       " 'Th', 'ch', '##hm', 'cha', 'chap', 'chapt', '##thm', 'Hu', 'Hug', 'Hugg', 'sh', 'th', 'is', '##thms', '##za', '##zat',\n",
       " '##ut']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_word(word):\n",
    "    tokens = []\n",
    "    while len(word) > 0:\n",
    "        i = len(word)\n",
    "        while i > 0 and word[:i] not in vocab:\n",
    "            i -= 1\n",
    "        if i == 0:\n",
    "            return [\"[UNK]\"]\n",
    "        tokens.append(word[:i])\n",
    "        word = word[i:]\n",
    "        if len(word) > 0:\n",
    "            word = f\"##{word}\"\n",
    "    return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Hugg', '##i', '##n', '##g']\n",
       "['[UNK]']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(encode_word(\"Hugging\"))\n",
    "print(encode_word(\"HOgging\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize(text):\n",
    "    pre_tokenize_result = tokenizer._tokenizer.pre_tokenizer.pre_tokenize_str(text)\n",
    "    pre_tokenized_text = [word for word, offset in pre_tokenize_result]\n",
    "    encoded_words = [encode_word(word) for word in pre_tokenized_text]\n",
    "    return sum(encoded_words, [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Th', '##i', '##s', 'is', 'th', '##e', 'Hugg', '##i', '##n', '##g', 'Fac', '##e', 'c', '##o', '##u', '##r', '##s',\n",
       " '##e', '[UNK]']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenize(\"This is the Hugging Face course!\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "WordPiece tokenization",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
